{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "7c76a62f",
   "metadata": {},
   "source": [
    "# GPU check\n",
    "\n",
    "## Check NVIDIA driver"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5c993d00",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-13T17:18:50.855038Z",
     "start_time": "2023-04-13T17:18:49.989079Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Thu May 18 03:47:56 2023       \n",
      "+-----------------------------------------------------------------------------+\n",
      "| NVIDIA-SMI 520.61.05    Driver Version: 520.61.05    CUDA Version: 11.8     |\n",
      "|-------------------------------+----------------------+----------------------+\n",
      "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
      "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
      "|                               |                      |               MIG M. |\n",
      "|===============================+======================+======================|\n",
      "|   0  Tesla T4            Off  | 00000000:00:07.0 Off |                    0 |\n",
      "| N/A   45C    P0    27W /  70W |      2MiB / 15360MiB |      6%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "                                                                               \n",
      "+-----------------------------------------------------------------------------+\n",
      "| Processes:                                                                  |\n",
      "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
      "|        ID   ID                                                   Usage      |\n",
      "|=============================================================================|\n",
      "|  No running processes found                                                 |\n",
      "+-----------------------------------------------------------------------------+\n"
     ]
    }
   ],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "6e098662",
   "metadata": {},
   "source": [
    "## Check GPU in PyTorch\n",
    "### Import PyTorch to check GPU devices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7bfa0d38",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-13T17:18:51.794978Z",
     "start_time": "2023-04-13T17:18:50.852002Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "96268b62",
   "metadata": {},
   "source": [
    "### Check the version of PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e2f00232",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-13T17:18:51.827270Z",
     "start_time": "2023-04-13T17:18:51.799271Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'2.0.0+cu117'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.__version__"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "248cc2d0",
   "metadata": {},
   "source": [
    "### Check if PyTorch can call GPUs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "84e571e1",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-13T17:18:51.890823Z",
     "start_time": "2023-04-13T17:18:51.824271Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.cuda.is_available()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "46e03918",
   "metadata": {},
   "source": [
    "### Check the number of the GPUs of the computer in PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b37ff8de",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-13T17:18:51.897822Z",
     "start_time": "2023-04-13T17:18:51.854727Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n"
     ]
    }
   ],
   "source": [
    "gpu_num = torch.cuda.device_count()\n",
    "print(gpu_num)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "ea613451",
   "metadata": {},
   "source": [
    "### Check the GPU type of the computer in PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b8b9a570",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-13T17:18:51.897822Z",
     "start_time": "2023-04-13T17:18:51.882725Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GPU 0.: Tesla T4\n"
     ]
    }
   ],
   "source": [
    "for i in range(gpu_num):\n",
    "    print('GPU {}.: {}'.format(i, torch.cuda.get_device_name(i)))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "dde8aefb",
   "metadata": {},
   "source": [
    "## Check GPU in TensorFlow\n",
    "### Import TensorFlow to check GPU devices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "10df2c49",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-13T17:18:53.523525Z",
     "start_time": "2023-04-13T17:18:51.882725Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-05-18 03:48:02.355928: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2023-05-18 03:48:03.469790: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "7161c2a2",
   "metadata": {},
   "source": [
    "### Check the version of TensorFlow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "caffacb9",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-13T17:18:53.579778Z",
     "start_time": "2023-04-13T17:18:53.526524Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'2.12.0'"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.__version__"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "a2a4a346",
   "metadata": {},
   "source": [
    "### Check if TensorFlow can call GPUs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8b273d8a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-13T17:18:56.347645Z",
     "start_time": "2023-04-13T17:18:53.578771Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /tmp/ipykernel_3029/337460670.py:1: is_gpu_available (from tensorflow.python.framework.test_util) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.config.list_physical_devices('GPU')` instead.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-05-18 03:48:05.350918: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:996] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
      "2023-05-18 03:48:05.353979: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:996] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
      "2023-05-18 03:48:05.355430: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:996] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
      "2023-05-18 03:48:08.207251: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:996] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
      "2023-05-18 03:48:08.209036: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:996] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
      "2023-05-18 03:48:08.210484: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:996] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
      "2023-05-18 03:48:08.211886: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1635] Created device /device:GPU:0 with 13167 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:00:07.0, compute capability: 7.5\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.test.is_gpu_available()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "aa90ea72",
   "metadata": {},
   "source": [
    "### Check physical GPUs in TensorFlow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "16e31ef0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-13T17:18:56.401645Z",
     "start_time": "2023-04-13T17:18:56.363645Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-05-18 03:48:08.223882: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:996] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
      "2023-05-18 03:48:08.225389: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:996] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
      "2023-05-18 03:48:08.226798: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:996] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.config.list_physical_devices('GPU')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "9347be86",
   "metadata": {},
   "source": [
    "### Check the GPU type of the computer in TensorFlow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "318c3603",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-13T17:18:56.457653Z",
     "start_time": "2023-04-13T17:18:56.389649Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "device: 0, name: Tesla T4, pci bus id: 0000:00:07.0, compute capability: 7.5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-05-18 03:48:08.237404: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:996] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
      "2023-05-18 03:48:08.238885: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:996] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
      "2023-05-18 03:48:08.240274: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:996] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
      "2023-05-18 03:48:08.241709: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:996] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
      "2023-05-18 03:48:08.243151: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:996] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
      "2023-05-18 03:48:08.244566: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1635] Created device /device:GPU:0 with 13167 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:00:07.0, compute capability: 7.5\n"
     ]
    }
   ],
   "source": [
    "from tensorflow.python.client import device_lib\n",
    "\n",
    "local_device_protos = device_lib.list_local_devices()\n",
    "for x in local_device_protos:\n",
    "    if x.device_type == 'GPU':\n",
    "        print(x.physical_device_desc)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "abbbf2db",
   "metadata": {},
   "source": [
    "## Check GPU in Jax and jaxlib\n",
    "### Import Jax and jaxlib to check GPU devices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "374fe8bb",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-13T17:18:56.505645Z",
     "start_time": "2023-04-13T17:18:56.419645Z"
    }
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "import jaxlib"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "16a801b9",
   "metadata": {},
   "source": [
    "### Check the version of jax and jaxlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "fa47eb20",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-13T17:18:56.505645Z",
     "start_time": "2023-04-13T17:18:56.505645Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'0.4.1'"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "jax.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "8eda12a5",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-13T17:18:56.506646Z",
     "start_time": "2023-04-13T17:18:56.505645Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'0.4.1'"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "jaxlib.__version__"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "423a8016",
   "metadata": {},
   "source": [
    "### Check all the devices from the default backend of jax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "1f643887",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-13T17:18:56.506646Z",
     "start_time": "2023-04-13T17:18:56.505645Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "jax.devices()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "dbcda938",
   "metadata": {},
   "source": [
    "### Check the total number of devices of jax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "f7dbdfc0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-13T17:18:56.506646Z",
     "start_time": "2023-04-13T17:18:56.505645Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "jax.device_count()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "d3b1b71b",
   "metadata": {},
   "source": [
    "### Check the number of JAX processes associated with the backend of jax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "c6847ebe",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-13T17:18:56.523880Z",
     "start_time": "2023-04-13T17:18:56.505645Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "jax.process_count()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "6b89568d",
   "metadata": {},
   "source": [
    "### Check the backend of jaxlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "a490cc61",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-13T17:18:56.524833Z",
     "start_time": "2023-04-13T17:18:56.523880Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpu\n"
     ]
    }
   ],
   "source": [
    "from jax.lib import xla_bridge\n",
    "\n",
    "print(xla_bridge.get_backend().platform)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "79e98d71",
   "metadata": {},
   "source": [
    "### Run a linear regression demo to verify that the Jax can run well\n",
    "The demo code is from [https://www.secretflow.org.cn/docs/secretflow/en/tutorial/lr_with_spu.html](https://www.secretflow.org.cn/docs/secretflow/en/tutorial/lr_with_spu.html)\n",
    "#### load dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "61ce21f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.datasets import load_breast_cancer\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import Normalizer\n",
    "\n",
    "\n",
    "def breast_cancer(party_id=None, train: bool = True) -> (np.ndarray, np.ndarray):\n",
    "    x, y = load_breast_cancer(return_X_y=True)\n",
    "    x = (x - np.min(x)) / (np.max(x) - np.min(x))\n",
    "    x_train, x_test, y_train, y_test = train_test_split(\n",
    "        x, y, test_size=0.2, random_state=42\n",
    "    )\n",
    "\n",
    "    if train:\n",
    "        if party_id:\n",
    "            if party_id == 1:\n",
    "                return x_train[:, :15], _\n",
    "            else:\n",
    "                return x_train[:, 15:], y_train\n",
    "        else:\n",
    "            return x_train, y_train\n",
    "    else:\n",
    "        return x_test, y_test"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "c5da1e71",
   "metadata": {},
   "source": [
    "#### define model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "9c25b432",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def sigmoid(x):\n",
    "    return 1 / (1 + jnp.exp(-x))\n",
    "\n",
    "\n",
    "# Outputs probability of a label being true.\n",
    "def predict(W, b, inputs):\n",
    "    return sigmoid(jnp.dot(inputs, W) + b)\n",
    "\n",
    "\n",
    "# Training loss is the negative log-likelihood of the training examples.\n",
    "def loss(W, b, inputs, targets):\n",
    "    preds = predict(W, b, inputs)\n",
    "    label_probs = preds * targets + (1 - preds) * (1 - targets)\n",
    "    return -jnp.mean(jnp.log(label_probs))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "ee9c7668",
   "metadata": {},
   "source": [
    "#### build train step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "4cd52139",
   "metadata": {},
   "outputs": [],
   "source": [
    "from jax import grad\n",
    "\n",
    "\n",
    "def train_step(W, b, x1, x2, y, learning_rate):\n",
    "    x = jnp.concatenate([x1, x2], axis=1)\n",
    "    Wb_grad = grad(loss, (0, 1))(W, b, x, y)\n",
    "    W -= learning_rate * Wb_grad[0]\n",
    "    b -= learning_rate * Wb_grad[1]\n",
    "    return W, b"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "55cc512d",
   "metadata": {},
   "source": [
    "#### build fit function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "58800872",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit(W, b, x1, x2, y, epochs=1, learning_rate=1e-2):\n",
    "    for _ in range(epochs):\n",
    "        W, b = train_step(W, b, x1, x2, y, learning_rate=learning_rate)\n",
    "    return W, b"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "4175c8ed",
   "metadata": {},
   "source": [
    "#### validate model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "c5fa81ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "\n",
    "def validate_model(W, b, X_test, y_test):\n",
    "    y_pred = predict(W, b, X_test)\n",
    "    return roc_auc_score(y_test, y_pred)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "366f5cb1",
   "metadata": {},
   "source": [
    "#### train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "7094805c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "auc=0.9880445463478545\n"
     ]
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "# Load the data\n",
    "x1, _ = breast_cancer(party_id=1, train=True)\n",
    "x2, y = breast_cancer(party_id=2, train=True)\n",
    "\n",
    "# Hyperparameter\n",
    "W = jnp.zeros((30,))\n",
    "b = 0.0\n",
    "epochs = 10\n",
    "learning_rate = 1e-2\n",
    "\n",
    "# Train the model\n",
    "W, b = fit(W, b, x1, x2, y, epochs=10, learning_rate=1e-2)\n",
    "\n",
    "# Validate the model\n",
    "X_test, y_test = breast_cancer(train=False)\n",
    "auc = validate_model(W, b, X_test, y_test)\n",
    "print(f'auc={auc}')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "80242142",
   "metadata": {},
   "source": [
    "As you can see, the warning message\n",
    "\n",
    ">No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n",
    "\n",
    "doesn't appear, which indicate the jax run the code in GPU"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "3f57b47e",
   "metadata": {},
   "source": [
    "## Import SecretFlow to verify that no errors are reported in this environemnt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "c8a443d6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-13T17:18:56.927007Z",
     "start_time": "2023-04-13T17:18:56.523880Z"
    }
   },
   "outputs": [],
   "source": [
    "import secretflow"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
